#!/usr/bin/env python
#!encoding:utf-8

import time
import struct
import Config as cf
#import numpy as np
from datetime import datetime
from tkinter import *
import tkinter.filedialog as filedialog
from sklearn.externals import joblib
from sklearn.neighbors import NearestNeighbors

TRAIN = []    # Module-level buffer holding the training rows of the chunk currently being read


"""
function: 获取当前毫秒级时间
return: 当前时间
"""
def get_time_stamp():
    """Return the current local time as a string with millisecond precision.

    Returns:
        str: timestamp formatted as ``YYYY-MM-DD HH:MM:SS.mmm``.
    """
    ct = time.time()
    data_head = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(ct))
    # int() replaces the Python-2-only long(): it truncates the same way and
    # exists on both Python 2 and Python 3.
    data_secs = (ct - int(ct)) * 1000
    return "%s.%03d" % (data_head, data_secs)


"""
function: 完整读取二进制文件
parameter: filename 文件路径
parameter: nx 文件行数
parameter: nz 文件列数
return: dataList 返回文件存储的信息
"""
def loadBinFiles(filename, nx, nz):
    """Read ``nx`` rows of ``nz`` native 4-byte floats from a binary file.

    The first value of every row is read but discarded, so each returned
    row holds ``nz - 1`` values.

    Args:
        filename: path of the binary file.
        nx: number of rows to read.
        nz: number of float values stored per row.

    Returns:
        list: ``nx`` lists of floats, each without the row's first value.

    Raises:
        struct.error: if the file ends before ``nx`` full rows were read.
    """
    # Unpack a whole row per read instead of one float per read.
    row_format = "%df" % nz
    row_size = struct.calcsize(row_format)
    dataList = []
    # The context manager closes the file even when unpacking fails.
    with open(filename, "rb") as f:
        for _ in range(nx):
            row = struct.unpack(row_format, f.read(row_size))
            dataList.append(list(row[1:]))
    return dataList


"""
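function: Read the training file chunk by chunk and, for every chunk, call ballTree to find the nearest neighbour of each test sample
parameter: filename path of the training file
parameter: testing test-set samples to look up
parameter: nx number of rows in the file
parameter: nz number of columns in the file
return: distances2 nearest distance found for each test sample
return: indices2 ID of the nearest training sample for each test sample
return: timeSum accumulated tree-building time
return: timeSum1 accumulated search time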
"""
def loadPartData(filename, testing, nx ,nz):
    """Search the training file chunk by chunk for each test sample's nearest row.

    Rows are read into the module-level ``TRAIN`` buffer. Every
    ``cf.BALL_TREE_SZIE`` rows (and once more for a trailing partial chunk)
    a ball tree is built over the buffer via ``ballTree``, the chunk-local
    IDs are shifted by ``IDRestore``, and the result is merged into the
    running best with ``sizeCompare``.

    Args:
        filename: path of the binary training file (native 4-byte floats).
        testing: test samples passed through to ``ballTree``.
        nx: number of rows to read from the file.
        nz: number of float values per row; the first one is discarded.

    Returns:
        tuple: ``(distances, indices, timeSum, timeSum1)`` holding the
        smallest distance per test sample, the matching training ID, and
        the accumulated build and search times as reported by ``ballTree``.
        Both lists are empty when no row was read.

    Raises:
        struct.error: if the file ends before ``nx`` full rows were read.
    """
    # Must precede every use of TRAIN: Python 3 rejects a late ``global``.
    global TRAIN
    # Start from an empty buffer so chunk offsets line up with file rows,
    # even if an earlier call was interrupted half-way through a chunk.
    TRAIN = []
    chunk_no = 0
    best_distances = best_indices = None
    timeSum = timeSum1 = 0.0
    row_format = "%df" % nz
    row_size = struct.calcsize(row_format)
    with open(filename, "rb") as f:
        for i in range(nx):
            row = struct.unpack(row_format, f.read(row_size))
            TRAIN.append(list(row[1:]))
            chunk_full = (i + 1) % cf.BALL_TREE_SZIE == 0
            last_row = (i + 1 == nx)
            # Also flush on the last row so a trailing partial chunk is
            # searched instead of being silently dropped.
            if not (chunk_full or last_row):
                continue
            chunk_no += 1
            # ballTree hands back an empty list, which resets the buffer.
            distances, indices, TRAIN, timesum, timesum1 = ballTree(TRAIN, testing)
            indices = IDRestore(indices, chunk_no)
            if best_distances is None:
                best_distances, best_indices = distances, indices
            else:
                # Fold every chunk into the running best; the strict
                # comparison keeps the earlier chunk on ties.
                best_distances, best_indices = sizeCompare(
                    best_distances, best_indices, distances, indices)
            timeSum = timeSum + timesum
            timeSum1 = timeSum1 + timesum1
    if best_distances is None:
        # No chunk was built (nx == 0), so there is nothing to report.
        return [], [], timeSum, timeSum1
    return best_distances, best_indices, timeSum, timeSum1


"""
function: 计算时间差
parameter: startTime 开始时间
parameter: endTime 结束时间
return: 两者时间差
"""
def timeDifference(startTime, endTime):
    """Return the difference in milliseconds between two timestamp strings.

    Args:
        startTime: start timestamp in ``%Y-%m-%d %H:%M:%S.%f`` form.
        endTime: end timestamp in the same form.

    Returns:
        ``endTime`` minus ``startTime`` in milliseconds; negative when
        ``endTime`` is the earlier of the two.
    """
    def _to_millis(stamp):
        # Whole seconds come from mktime, the fraction from the microseconds.
        parsed = datetime.strptime(stamp, '%Y-%m-%d %H:%M:%S.%f')
        return time.mktime(parsed.timetuple())*1000 + parsed.microsecond/1000

    start_ms = _to_millis(startTime)
    end_ms = _to_millis(endTime)
    return end_ms - start_ms


"""
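function: Build a ball tree from the given training set and look up the nearest distance and ID for every test sample
parameter: training training set
parameter: testing test set
return: distances nearest distance for each test sample
return: indices position of the nearest sample inside the training set
return: training training set (handed back as an empty list)
return: timesum time spent building the tree
return: timesum1 time spent searching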
"""
def ballTree(training, testing, count=None):
    """Build a ball tree over ``training`` and find each test sample's nearest row.

    Args:
        training: training samples used to fit the tree.
        testing: test samples to look up.
        count: optional number used in the saved model's file name. When
            omitted, a per-process running counter is used instead.

    Returns:
        tuple: ``(distances, indices, training, timesum, timesum1)`` where
        ``distances`` and ``indices`` hold one nearest distance and one
        position in ``training`` per test sample, ``training`` is a fresh
        empty list for the caller to reuse as its buffer, and ``timesum``
        and ``timesum1`` are the build and search durations in
        milliseconds.
    """
    print('建树时间')
    time1 = get_time_stamp()
    print(time1)
    nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree", metric='euclidean').fit(training)    # build the ball tree
    time2 = get_time_stamp()
    print(time2)
    timesum = timeDifference(time1, time2)
    if cf.SAVE_MODEL:
        if count is None:
            # The original referenced an undefined ``count`` here; fall back
            # to a running counter so each saved tree gets its own file.
            ballTree.build_count = getattr(ballTree, 'build_count', 0) + 1
            count = ballTree.build_count
        joblib.dump(nbrs, cf.MODEL_WAREHOUSE+'Tree'+str(count)+'.pkl')    # save the model
    print('检索时间')
    time3 = get_time_stamp()
    print(time3)
    if len(testing):
        # One batched query replaces a kneighbors call per test sample.
        dis, ind = nbrs.kneighbors(testing, 1)
        distances = [float(row[0]) for row in dis]
        indices = [int(row[0]) for row in ind]
    else:
        distances, indices = [], []
    time4 = get_time_stamp()
    print(time4)
    timesum1 = timeDifference(time3, time4)
    return distances, indices, [], timesum, timesum1


"""
function: 因为分割读取导致建树时ID改变,修改原ID
parameter: indices 数据集中ID值
parameter: count 建树次数
return: indices 修改后的ID值
"""
def IDRestore(indices, count, chunk_size=None):
    """Shift chunk-local IDs to their position in the whole training set.

    Args:
        indices: list of IDs local to one chunk; modified in place.
        count: 1-based number of the chunk the IDs came from.
        chunk_size: number of rows per chunk. Defaults to
            ``cf.BALL_TREE_SZIE``, the chunk size used when reading the
            training file, instead of a hard-coded 10000.

    Returns:
        list: the same ``indices`` list with every ID shifted by
        ``chunk_size * (count - 1)``.
    """
    if chunk_size is None:
        chunk_size = cf.BALL_TREE_SZIE
    # Chunk numbers start at 1 while offsets start at 0, hence count - 1.
    offset = chunk_size * (count - 1)
    for pos, local_id in enumerate(indices):
        indices[pos] = local_id + offset
    return indices


"""
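function: Element-wise comparison of two result sets, keeping the smaller distance
parameter: distances first list of distances
parameter: indices first list of IDs
parameter: distances1 second list of distances
parameter: indices1 second list of IDs
return: distances the smaller distance at each position
return: indices the ID belonging to the smaller distance at each position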
"""
def sizeCompare(distances, indices, distances1, indices1):
    """Merge two nearest-neighbour results, keeping the smaller distance.

    ``distances`` and ``indices`` are updated in place: wherever the entry
    in ``distances1`` is strictly smaller, it and its ID from ``indices1``
    replace the current entry. Ties keep the current entry.

    Args:
        distances: current best distances; modified in place.
        indices: IDs matching ``distances``; modified in place.
        distances1: candidate distances.
        indices1: IDs matching ``distances1``.

    Returns:
        tuple: the updated ``distances`` and ``indices`` lists.
    """
    for pos in range(len(indices)):
        challenger = distances1[pos]
        if not distances[pos] > challenger:
            continue
        distances[pos] = challenger
        indices[pos] = indices1[pos]
    return distances, indices


def callback():
    """Let the user pick a file and store its path in the global ``filepath``.

    The chosen path is also shown in ``entry`` (the box above the
    training-set button) and printed. Cancelling the dialog leaves the
    entry empty.
    """
    global filepath
    entry.delete(0,END)    # clear whatever the entry currently shows
    # askopenfilename() opens a file-selection dialog (not a directory one).
    filepath = filedialog.askopenfilename()
    if filepath:
        entry.insert(0,filepath)    # show the chosen path in the entry
    # print() with one argument behaves the same on Python 2 and Python 3.
    print(filepath)


def callback1():
    """Let the user pick a file and store its path in the global ``filepath1``.

    The chosen path is also shown in ``entry1`` (the box above the
    test-set button) and printed. Cancelling the dialog leaves the entry
    empty.
    """
    global filepath1
    entry1.delete(0,END)    # clear whatever the entry currently shows
    filepath1 = filedialog.askopenfilename()
    if filepath1:
        entry1.insert(0,filepath1)    # show the chosen path in the entry
    # print() with one argument behaves the same on Python 2 and Python 3.
    print(filepath1)


def callback2():
    """Run the nearest-neighbour search on the two files chosen in the GUI.

    Loads the test set from the path stored in ``filepath1``, searches the
    training file stored in ``filepath`` chunk by chunk, and prints the
    start and end time plus the accumulated build and search times in
    seconds. Returns early with a message when either file is missing.
    """
    # The paths only exist as globals once the file-picker callbacks ran,
    # so look them up defensively instead of raising NameError.
    chosen = globals()
    filename = chosen.get('filepath1')    # test set; the original note reads "100,1025" (presumably its shape)
    filename1 = chosen.get('filepath')    # training set; the original note reads "1000000,1025"
    if not filename or not filename1:
        print('请先选择训练集文件和测试集文件')
        return
    # Single-argument print() calls keep the output identical on Python 2 and 3.
    print('%s %s' % (filename, filename1))
    print('开始时间: '+get_time_stamp())
    testing = loadBinFiles(filename,cf.TESTING_SIZE,cf.DATA_DIMENSION)
    distances, indices, timeSum, timeSum1 = loadPartData(filename1, testing, cf.TRAINING_SIZE, cf.DATA_DIMENSION)
    print('结束时间: '+get_time_stamp())
    # The accumulated times are in milliseconds; report them in seconds.
    print('建树时间为: %s 秒' % (timeSum/1000.0))
    print('检索时间为: %s 秒' % (timeSum1/1000.0))

if __name__ == "__main__":
    # Main window.
    root = Tk()
    root.title("标题狗")
    root.geometry("400x400")

    # Create the widgets first: one entry and one picker button per input
    # file, plus the button that starts the run.
    entry = Entry(root, width=60)
    button = Button(root, command=callback, text="选择训练集文件")
    entry1 = Entry(root, width=60)
    button1 = Button(root, command=callback1, text="选择测试集文件")
    button2 = Button(root, command=callback2, text="开始运行")

    # Then lay them out top to bottom in a single grid column.
    entry.grid(row=0, column=0, columnspan=4, padx=5, pady=5, sticky=W+N)
    button.grid(row=1, column=0, padx=5, pady=5, sticky=W+N)
    entry1.grid(row=2, column=0, columnspan=4, padx=5, pady=5, sticky=W+N)
    button1.grid(row=3, column=0, padx=5, pady=5, sticky=W+N)
    button2.grid(row=4, column=0, padx=5, pady=5, sticky=W+N)

    root.mainloop()



"""
#思路：
1 先把这10w个存一个list    <***>
2 然后先一小部分读数据,把10w数据和这读取的一部分做一个训练集    <***>
3 去建树查找10w最小值,保存最小值和对应id    <***>
4 对list分片删除读取的数据    <***>
5 将本次最近结果和上次做比较保留最小值和对应的id    <***>
6 循环上述步骤    <***>
"""

